Feat/more losses #845
Conversation
Codecov Report

@@            Coverage Diff             @@
##           master     #845      +/-   ##
==========================================
+ Coverage   91.40%   91.43%   +0.03%
==========================================
  Files          70       71       +1
  Lines        7106     7135      +29
==========================================
+ Hits         6495     6524      +29
  Misses        611      611
==========================================

Continue to review full report at Codecov.
darts/utils/losses.py (Outdated)

        super().__init__()

    def forward(self, inpt, tgt):
        return torch.mean(torch.abs(inpt - tgt))
Isn't this just the MAE? Or is this to overcome some of the issues with MAPE?
Suggested change:

-        return torch.mean(torch.abs(inpt - tgt))
+        return torch.mean(torch.abs(_divide_no_nan(inpt - tgt, inpt)))
+1
You're right. Initially I ignored the denominator because it only impacts the magnitude of the gradients, and it was giving somewhat better results, but it's not quite correct.
I have changed it and also added MAE to the list now (unit test on its way) :)
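For context, a minimal sketch of the resulting MAPE-style criterion. `_divide_no_nan` is assumed here to be a small safe-division helper in darts/utils/losses.py; its exact implementation isn't shown in this thread, so the version below is an assumption:

import torch
import torch.nn as nn

def _divide_no_nan(a, b):
    # Assumed behavior: elementwise a / b, with the NaN/inf entries
    # produced by b == 0 replaced by 0 so gradients stay finite.
    div = a / b
    div[div != div] = 0.0                  # NaN -> 0
    div[div.abs() == float("inf")] = 0.0   # +/-inf -> 0
    return div

class MapeLoss(nn.Module):
    def forward(self, inpt, tgt):
        # Mean absolute percentage error, with inpt as the denominator
        # as in the suggestion above.
        return torch.mean(torch.abs(_divide_no_nan(inpt - tgt, inpt)))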
darts/tests/utils/test_losses.py (Outdated)

        air_s = scaler.fit_transform(air)
        air_train, air_val = air_s[:-36], air_s[-36:]

    def test_smape_loss(self):
How about we check the actual output of the losses instead of fitting the models?
Just thinking about execution time; fitting takes a couple of seconds.
Yes, I thought of that as well. Although actually using the loss functions for fitting might reveal some problems that we wouldn't notice otherwise.
Yes, it's a tiny bit better to test the fitting to make sure the gradients are kept where they should be, so we can leave it like that for now.
On second thoughts, I think your idea is better, as long as we're also checking the loss gradients. I've changed the tests to do that now, thanks for the suggestion 👍
Nice addition! 👍
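To make the agreed approach concrete, here is a self-contained sketch of checking a loss value together with the gradient it produces (the tensors, the W -> y_hat mapping, and the expected values are illustrative, not taken from the PR):

import torch

def check_value_and_grad(loss_fn, exp_loss_val, exp_w_grad):
    # A single weight tensor standing in for the model parameters.
    W = torch.tensor([0.5, -0.5], requires_grad=True)
    y = torch.tensor([1.0, 1.0])
    y_hat = W * 2.0               # some differentiable function of W
    lval = loss_fn(y_hat, y)
    lval.backward()               # populates W.grad
    assert torch.allclose(lval, exp_loss_val)
    assert torch.allclose(W.grad, exp_w_grad)

# e.g. for L1Loss: y_hat = [1, -1], y = [1, 1] -> loss = mean([0, 2]) = 1,
# and dloss/dW = 2 * sign(y_hat - y) / 2 = [0, -1]
check_value_and_grad(torch.nn.L1Loss(), torch.tensor(1.0), torch.tensor([0.0, -1.0]))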
Very nice, thanks a lot!
After addressing the last suggestions, it can be merged
darts/tests/utils/test_losses.py

    def helper_test_loss(self, exp_loss_val, exp_w_grad, loss_fn):
        W = torch.tensor([[0.1, -0.2, 0.3, -0.4], [-0.8, 0.7, -0.6, 0.5]])
        W.requires_grad = True
Nice tests +1
darts/tests/utils/test_losses.py (Outdated)

        lval = loss_fn(y_hat, self.y)
        lval.backward()

        print(lval)
should be removed
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
Add two new PyTorch loss functions (SmapeLoss and MapeLoss), which provide alternative training criteria and could, for instance, be used to replicate some of the M3/M4 competition results.
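For reference, a minimal sketch of the sMAPE criterion such a SmapeLoss would compute (the actual darts implementation may guard the division differently, e.g. with the _divide_no_nan helper discussed above):

import torch
import torch.nn as nn

class SmapeLoss(nn.Module):
    def forward(self, inpt, tgt):
        # sMAPE: mean of 2 * |y - y_hat| / (|y| + |y_hat|)
        return torch.mean(
            2.0 * torch.abs(tgt - inpt) / (torch.abs(tgt) + torch.abs(inpt))
        )

Such a criterion would presumably be passed to the darts torch-based models via their loss_fn argument, e.g. loss_fn=SmapeLoss().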